import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from ResNet import resnet50  # 从自定义的ResNet.py文件中导入resnet50这个函数
import numpy as np
import matplotlib.pyplot as plt
from utils import read_split_data, train_one_epoch, evaluate
from my_dataset import MyDataSet

# -------------------------------------------------- #
# (0) Hyper-parameters
# -------------------------------------------------- #
batch_size = 256  # number of images processed per training step
epochs = 100      # total number of training epochs

# -------------------------------------------------- #
# (1) File configuration
# -------------------------------------------------- #
# dataset root directory
filepath = r'C:\Users\w8887757\PycharmProjects\wm811k\data'
# path to the pretrained weight file
weightpath = './input/pretrained_weights/resnet50.pth'
# directory where trained weights are saved
savepath = './input/save_weights/'

# Select the computation device: prefer the first GPU, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# -------------------------------------------------- #
# (2) Preprocessing pipelines
# -------------------------------------------------- #
# Training pipeline: resize, augment, convert to tensor, then normalize.
transform_train = transforms.Compose([
    transforms.Resize((26, 26)),             # resize every image to 26x26
    transforms.RandomRotation(45),           # random rotation in [-45, 45] degrees
    transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip with 50% probability
    transforms.RandomVerticalFlip(p=0.5),    # vertical flip with 50% probability
    # to tensor: scales pixels into [0, 1] and reorders [h,w,c] -> [c,h,w]
    transforms.ToTensor(),
    # per-channel standardization with the given means and standard deviations
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Validation pipeline: resize and normalize only — no augmentation.
transform_val = transforms.Compose([
    transforms.Resize((26, 26)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# -------------------------------------------------- #
# (3) Build the datasets and data loaders
# -------------------------------------------------- #
# Split the image files under <filepath>\train into training and validation sets.
train_images_path, train_images_label, val_images_path, val_images_label = \
    read_split_data(filepath + r'\train')

# Wrap each split in a Dataset together with its preprocessing pipeline.
train_dataset = MyDataSet(images_path=train_images_path,
                          images_class=train_images_label,
                          transform=transform_train)
val_dataset = MyDataSet(images_path=val_images_path,
                        images_class=val_images_label,
                        transform=transform_val)

# Report the number of images in each split.
train_num, val_num = len(train_dataset), len(val_dataset)
print('train_num:', train_num, 'val_num:', val_num)

# Mapping from class name to label index (8 wafer-map defect classes).
class_names = ['Center', 'Donut', 'Edge-Loc', 'Edge-Ring',
               'Loc', 'Near-full', 'Random', 'Scratch']
class_dict = {name: index for index, name in enumerate(class_names)}
print(class_dict)

# Training loader: shuffled mini-batches.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=0)  # must stay 0 on Windows

# Validation loader: deterministic order, no shuffling.
val_loader = DataLoader(dataset=val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=0)

# -------------------------------------------------- #
# (4) Data visualization (sanity check)
# -------------------------------------------------- #
# Pull a single batch of training data: images and their labels.
train_img, train_label = next(iter(train_loader))
# Expected shapes: images [b, 3, h, w], labels [b].
print(train_img.shape, train_label.shape)

# Keep the first nine images and undo the 0.5/0.5 normalization so the
# pixel values fall back into [0, 1] for display.
img = (train_img[:9] / 2 + 0.5).numpy()
class_label = train_label.numpy()
# Rearrange to channels-last for plotting: [b, c, h, w] -> [b, h, w, c].
img = np.transpose(img, [0, 2, 3, 1])

# -------------------------------------------------- #
# (5) Build the model, loss, and optimizer
# -------------------------------------------------- #
# ResNet-50 with an 8-way classification head.
net = resnet50(num_classes=8, include_top=True)

# NOTE(review): a pretrained checkpoint was previously read here with
# torch.load(weightpath) but never applied (load_state_dict was commented
# out), so the load was pure wasted I/O and memory — removed. To actually
# fine-tune from pretrained weights, load them with a non-strict match so
# the replaced classification head is skipped:
#   net.load_state_dict(torch.load(weightpath, map_location=device), strict=False)

# Re-create the classification layer with a fresh random init.
in_channel = net.fc.in_features  # 2048 for ResNet-50
net.fc = nn.Linear(in_channel, 8)  # [b, 2048] ==> [b, 8]

# Move the model to the device AFTER replacing the head, so the new fc
# layer is transferred as well (previously .to(device) was called twice,
# once uselessly before the replacement).
net.to(device)

# Cross-entropy loss for multi-class classification.
loss_function = nn.CrossEntropyLoss()

# Adam optimizer over all trainable parameters.
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Best validation accuracy seen so far (drives checkpointing) and the
# per-epoch metric histories used for the plots at the end of the script.
best_acc = 0.0
losses_train = []
accuracy_train = []
losses_val = []
accuracy_val = []
# -------------------------------------------------- #
# (6) Training and validation loop
# -------------------------------------------------- #
for epoch in range(epochs):

    print('-' * 30, '\n', 'epoch:', epoch)

    # Training mode: dropout active, batch-norm statistics updated.
    net.train()

    running_loss = 0.0    # sum of the per-batch losses over this epoch
    train_correct = 0.0   # number of correctly classified training images
    train_total = 0.0     # number of training images seen this epoch

    for step, data in enumerate(train_loader):
        images, labels = data
        # Move the labels once and reuse them for both loss and accuracy
        # (previously transferred to the device twice per step).
        labels = labels.to(device)

        # Gradients accumulate across backward() calls; reset each step.
        optimizer.zero_grad()

        # Forward pass.
        outputs = net(images.to(device))

        # Cross-entropy between predictions and ground truth.
        loss = loss_function(outputs, labels)
        predicted = torch.argmax(outputs, 1)
        train_correct += (predicted == labels).sum().item()
        train_total += labels.size(0)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # BUG FIX: the epoch-mean loss was computed as running_loss / step,
    # which divides by (num_batches - 1) — an off-by-one — and raises
    # ZeroDivisionError when the loader yields exactly one batch.
    num_batches = len(train_loader)
    train_loss = running_loss / num_batches
    train_accuracy = train_correct / train_total
    losses_train.append(train_loss)
    accuracy_train.append(train_accuracy)

    # -------------------------------------------------- #
    # (7) Validation
    # -------------------------------------------------- #
    net.eval()  # evaluation mode: dropout off, batch-norm uses running stats

    acc = 0.0         # number of correct validation predictions
    valid_loss = 0.0  # sum of validation batch losses
    with torch.no_grad():  # no gradient bookkeeping during evaluation

        for data_test in val_loader:
            test_images, test_labels = data_test
            test_labels = test_labels.to(device)

            # Forward pass.
            outputs = net(test_images.to(device))

            # Index of the maximum score is the predicted class.
            predict_y = torch.max(outputs, dim=1)[1]
            valid_loss += loss_function(outputs, test_labels).item()
            # Accumulate the number of correct predictions.
            acc += (predict_y == test_labels).sum().item()

        # Fraction of validation images classified correctly.
        acc_test = acc / val_num

        # Per-epoch summary (mean train loss fixed as above).
        print(f'total_train_loss:{train_loss}, total_test_acc:{acc_test}')

        # -------------------------------------------------- #
        # (8) Checkpointing
        # -------------------------------------------------- #
        # Save the weights whenever the validation accuracy improves.
        if acc_test > best_acc:
            best_acc = acc_test
            savename = savepath + 'resnet50_dataAug_epo100_class8.pth'
            torch.save(net.state_dict(), savename)

        losses_val.append(valid_loss / len(val_loader))
        accuracy_val.append(acc_test)
print(best_acc)
# X-axis values: one point per epoch (1..epochs).
epochs = np.arange(1, epochs + 1)

# Draw the loss curves on one figure and the accuracy curves on a second,
# in the same order as before, then display both.
curve_figures = [
    ('Training Loss', 'Loss',
     [(losses_train, 'Loss_train'), (losses_val, 'Loss_test')]),
    ('Training Accuracy', 'Accuracy',
     [(accuracy_train, 'Accuracy_train'), (accuracy_val, 'Accuracy_test')]),
]
for title, ylabel, curves in curve_figures:
    plt.figure()
    for series, label in curves:
        plt.plot(epochs, series, label=label)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()

plt.show()
